'''
Author: SlytherinGe
LastEditTime: 2021-08-17 10:16:22
'''
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time
import os

def os_cfar(img, guard_cells, bg_cells, alpha, padding=True, rank=0.75):
    '''
    Ordered-statistic CFAR (OS-CFAR) detection evaluated at every pixel.

    A square window of side ``1 + 2*guard_cells + 2*bg_cells`` is centred on
    each pixel. The inner ``(2*guard_cells + 1)``-square (the guard region,
    which contains the pixel itself) is ignored. The remaining background
    cells are ranked in ascending order, and the value at index
    ``int(num_bg * rank)`` is taken as the local clutter level. A pixel is a
    detection when its value is strictly greater than ``clutter_level * alpha``.

    Only channel 0 of ``img`` is used.

    Args:
        img: image array of shape [h, w, c] (or [h, w]); only channel 0 is read.
        guard_cells: half-width of the guard region (>= 0).
        bg_cells: width of the background ring around the guard region (>= 1).
        alpha: scale factor applied to the clutter level to form the threshold.
        padding: if True, the image is padded with the constant 1 so every
            pixel gets a full window. If False, pixels closer than
            ``guard_cells + bg_cells`` to the border are never detections.
        rank: fractional position of the order statistic, in [0, 1).

    Returns:
        Integer array of shape [h, w]: 255 for detections, 0 elsewhere.

    Raises:
        ValueError: if guard_cells < 0, bg_cells < 1 or rank is not in [0, 1).
    '''
    if guard_cells < 0 or bg_cells < 1:
        raise ValueError('guard_cells must be >= 0 and bg_cells must be >= 1')
    if not 0 <= rank < 1:
        raise ValueError('rank must be in [0, 1)')
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    chan = img[:, :, 0]
    im_h, im_w = chan.shape

    reach = guard_cells + bg_cells          # window half-width
    pad_size = reach if padding else 0
    cfar_unit_size = 2 * reach + 1

    # Pad only the channel that is used, keeping the source dtype so that
    # non-uint8 images are not silently truncated. Padding value is 1.
    padded = np.ones((im_h + 2 * pad_size, im_w + 2 * pad_size), dtype=chan.dtype)
    padded[pad_size:pad_size + im_h, pad_size:pad_size + im_w] = chan

    # True for background cells, False for the guard square in the middle.
    bg_mask = np.ones((cfar_unit_size, cfar_unit_size), dtype=bool)
    bg_mask[bg_cells:bg_cells + 2 * guard_cells + 1,
            bg_cells:bg_cells + 2 * guard_cells + 1] = False
    num_bg = int(bg_mask.sum())
    k = int(num_bg * rank)

    # Pixel centres (image coordinates) whose full window lies inside
    # `padded`: the whole image when padded, the interior otherwise.
    y0 = reach - pad_size
    y1 = max(y0, im_h + pad_size - reach)
    x0 = reach - pad_size
    x1 = max(x0, im_w + pad_size - reach)

    thresholds = np.zeros((y1 - y0, x1 - x0))
    for i in range(y0, y1):
        top = i + pad_size - reach
        for j in range(x0, x1):
            left = j + pad_size - reach
            window = padded[top:top + cfar_unit_size, left:left + cfar_unit_size]
            # np.partition places the k-th smallest value at index k: the same
            # order statistic as a full sort, in linear time.
            thresholds[i - y0, j - x0] = np.partition(window[bg_mask], k)[k]

    results = np.zeros((im_h, im_w), dtype=bool)
    results[y0:y1, x0:x1] = chan[y0:y1, x0:x1] > thresholds * alpha
    return results * 255

def bbox_guided_os_cfar(img, bbox, bg_cells, alpha, padding=True, rank=0.75):
    '''
    OS-CFAR detection restricted to the pixels inside a bounding box.

    The guard half-width is derived from the box as
    ``int(min(xmax - xmin, ymax - ymin) / 2)``, so the guard square roughly
    covers the target.

    For every pixel inside the box:
      * A window of side ``1 + 2*guard + 2*bg_cells`` is centred on the pixel.
      * The guard square is ignored.
      * The remaining background values are ranked in ascending order, and the
        one at index ``int(num_bg * rank)`` becomes the clutter level.
      * The pixel is a detection when its value is strictly greater than
        ``clutter_level * alpha``.

    Pixels outside the box are never detections. Only channel 0 is used.

    Args:
        img: image array of shape [h, w, c] (or [h, w]); only channel 0 is read.
        bbox: indexable whose items 0..3 are xmin, ymin, xmax, ymax in pixel
            coordinates (converted with int()). The max bounds are exclusive.
            Any part of the box outside the image is ignored.
        bg_cells: width of the background ring around the guard square (>= 1).
        alpha: scale factor applied to the clutter level to form the threshold.
        padding: if True, the image is padded with the constant 1 so every
            pixel gets a full window. If False, pixels closer than
            ``guard + bg_cells`` to the border are never detections.
        rank: fractional position of the order statistic, in [0, 1).

    Returns:
        Boolean array of shape [h, w]; True only for detections inside the box.

    Raises:
        ValueError: if bg_cells < 1 or rank is not in [0, 1).
    '''
    if bg_cells < 1:
        raise ValueError('bg_cells must be >= 1')
    if not 0 <= rank < 1:
        raise ValueError('rank must be in [0, 1)')
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    chan = img[:, :, 0]
    im_h, im_w = chan.shape

    xmin, ymin, xmax, ymax = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
    # Guard size comes from the unclipped box (the target size); a degenerate
    # box must not yield a negative guard.
    guard_cells = max(int(min(xmax - xmin, ymax - ymin) / 2), 0)
    reach = guard_cells + bg_cells          # window half-width
    pad_size = reach if padding else 0
    cfar_unit_size = 2 * reach + 1

    # Pad only the channel that is used, keeping the source dtype.
    padded = np.ones((im_h + 2 * pad_size, im_w + 2 * pad_size), dtype=chan.dtype)
    padded[pad_size:pad_size + im_h, pad_size:pad_size + im_w] = chan

    # True for background cells, False for the guard square in the middle.
    bg_mask = np.ones((cfar_unit_size, cfar_unit_size), dtype=bool)
    bg_mask[bg_cells:bg_cells + 2 * guard_cells + 1,
            bg_cells:bg_cells + 2 * guard_cells + 1] = False
    num_bg = int(bg_mask.sum())
    k = int(num_bg * rank)

    # Centres to evaluate: inside the box, inside the image, and with a full
    # window available in `padded`. Clipping prevents negative slice starts.
    y0 = max(ymin, reach - pad_size)
    y1 = max(y0, min(ymax, im_h + pad_size - reach))
    x0 = max(xmin, reach - pad_size)
    x1 = max(x0, min(xmax, im_w + pad_size - reach))

    thresholds = np.zeros((y1 - y0, x1 - x0))
    for i in range(y0, y1):
        top = i + pad_size - reach
        for j in range(x0, x1):
            left = j + pad_size - reach
            window = padded[top:top + cfar_unit_size, left:left + cfar_unit_size]
            # k-th smallest background value (same as sorting, linear time).
            thresholds[i - y0, j - x0] = np.partition(window[bg_mask], k)[k]

    # Explicit mask instead of a sentinel threshold, so pixels outside the
    # box stay False regardless of alpha.
    results = np.zeros((im_h, im_w), dtype=bool)
    results[y0:y1, x0:x1] = chan[y0:y1, x0:x1] > thresholds * alpha
    return results

if __name__ == '__main__':

    IMG_ROOT = '/media/gejunyao/Disk1/Datasets/SSDD/VOC2012/JPEGImages/'
    OUT_DIR = '/media/gejunyao/Disk/Gejunyao/exp_results/visualization/results/cfar/os-cfar-ssdd/'

    # Batch-run OS-CFAR (guard_cells=10, bg_cells=5, alpha=2) over every file
    # in IMG_ROOT and write the 0/255 detection map to OUT_DIR under the same
    # file name.
    # NOTE: bbox_guided_os_cfar returns a boolean mask; convert it, e.g.
    # `mask.astype(np.uint8) * 255`, before passing it to cv2.imwrite.

    # cv2.imwrite signals failure only through its return value; make sure the
    # destination directory exists up front.
    os.makedirs(OUT_DIR, exist_ok=True)
    imgs = sorted(os.listdir(IMG_ROOT))     # deterministic processing order
    num_img = len(imgs)

    for index, img in enumerate(imgs, start=1):
        img_path = os.path.join(IMG_ROOT, img)
        out_path = os.path.join(OUT_DIR, img)
        t1 = time.perf_counter()
        im = cv2.imread(img_path)
        if im is None:
            # cv2.imread returns None (it does not raise) for unreadable files.
            print('img {} skipped: could not be read, {}/{}'.format(img, index, num_img))
            continue
        res = os_cfar(im, 10, 5, 2)
        # os_cfar yields 0/255 integers; write them as an explicit 8-bit image.
        if not cv2.imwrite(out_path, res.astype(np.uint8)):
            print('img {} could not be written to {}'.format(img, out_path))
        t2 = time.perf_counter()
        print('img {} finished, {}/{}, using {:.4f}s'.format(img, index, num_img, t2-t1))
